﻿using DotNetCommon;
using DotNetCommon.Data;
using DotNetCommon.Extensions;
using DotNetCommon.Logger;
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

namespace DBUtil;

/// <summary>
/// Executes a delegate while holding a named lock.
/// </summary>
/// <remarks>
/// Every overload takes the <c>DBAccess</c> instance used to coordinate the lock, the lock name,
/// the work to run, and an optional timeout (in seconds) for acquiring the lock.
/// </remarks>
public interface ILocker
{
    /// <summary>Runs <paramref name="action"/> while holding the lock named <paramref name="lock_str"/>.</summary>
    /// <param name="db">Database access object used by the implementation to coordinate the lock.</param>
    /// <param name="lock_str">Name identifying the lock.</param>
    /// <param name="action">Work to execute while the lock is held.</param>
    /// <param name="getLockTimeoutSecond">Seconds to wait for the lock; defaults to 180 (3 minutes).</param>
    void RunInLock(DBAccess db, string lock_str, Action action, int getLockTimeoutSecond = 180);

    /// <summary>Runs <paramref name="func"/> while holding the lock and returns its result.</summary>
    /// <param name="db">Database access object used by the implementation to coordinate the lock.</param>
    /// <param name="lock_str">Name identifying the lock.</param>
    /// <param name="func">Work to execute while the lock is held.</param>
    /// <param name="getLockTimeoutSecond">Seconds to wait for the lock; defaults to 180 (3 minutes).</param>
    /// <returns>The value produced by <paramref name="func"/>.</returns>
    T RunInLock<T>(DBAccess db, string lock_str, Func<T> func, int getLockTimeoutSecond = 180);

    /// <summary>Asynchronously runs <paramref name="func"/> while holding the lock.</summary>
    /// <param name="db">Database access object used by the implementation to coordinate the lock.</param>
    /// <param name="lock_str">Name identifying the lock.</param>
    /// <param name="func">Asynchronous work to execute while the lock is held.</param>
    /// <param name="getLockTimeoutSecond">Seconds to wait for the lock; defaults to 180 (3 minutes).</param>
    Task RunInLockAsync(DBAccess db, string lock_str, Func<Task> func, int getLockTimeoutSecond = 180);

    /// <summary>Asynchronously runs <paramref name="func"/> while holding the lock and returns its result.</summary>
    /// <param name="db">Database access object used by the implementation to coordinate the lock.</param>
    /// <param name="lock_str">Name identifying the lock.</param>
    /// <param name="func">Asynchronous work to execute while the lock is held.</param>
    /// <param name="getLockTimeoutSecond">Seconds to wait for the lock; defaults to 180 (3 minutes).</param>
    /// <returns>The value produced by <paramref name="func"/>.</returns>
    Task<T> RunInLockAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond = 180);
}

/// <summary>
/// <see cref="ILocker"/> implementation that uses a database table as a cross-process lock.
/// </summary>
/// <remarks>
/// Acquisition happens in two stages:
/// <list type="number">
/// <item>An in-process <see cref="SemaphoreSlim"/> keyed by the lock string serializes callers inside this process.</item>
/// <item>A row is inserted into <see cref="DBLockTableName"/>. The table declares <c>lock_str</c> as <c>unique</c>,
/// so the insert fails while another holder owns the same lock, and the caller retries every 500 ms.</item>
/// </list>
/// While the lock is held, a background loop refreshes the row's <c>lock_time</c> every 5 seconds.
/// Waiters delete rows whose <c>lock_time</c> is older than 20 seconds, so the lock of a holder that
/// stopped refreshing (e.g. a crashed process) is eventually reclaimed.
/// </remarks>
/// <param name="dBLockTableName">Name of the table holding lock rows; it is created on first use if missing.</param>
public class CommonDBLocker(string dBLockTableName) : ILocker
{
    // Connection strings for which the lock table has already been verified/created.
    private static readonly ConcurrentDictionary<string, bool> _lockcaches = [];
    // One in-process semaphore per lock string (shared across all databases).
    // NOTE(review): entries are never removed, so highly dynamic lock strings make this map grow without bound.
    private static readonly ConcurrentDictionary<string, SemaphoreSlim> _lockSemaphores = [];
    private readonly ILogger<CommonDBLocker> logger = LoggerFactory.CreateLogger<CommonDBLocker>();

    /// <summary>Name of the table that stores lock rows.</summary>
    public string DBLockTableName { get; private set; } = dBLockTableName;

    /// <summary>
    /// Escapes <paramref name="value"/> for embedding inside a single-quoted SQL string literal by doubling quotes.
    /// Without this, a lock name containing <c>'</c> breaks every statement (and the insert failure would be
    /// mistaken for "lock is held" until the timeout expires).
    /// </summary>
    private static string EscapeSqlLiteral(string value) => value.Replace("'", "''");

    /// <summary>
    /// Ensures the lock table exists for <paramref name="db"/>'s connection string.
    /// The check runs at most once per connection string per process.
    /// </summary>
    private async Task EnsureInitLockAsync(DBAccess db)
    {
        if (_lockcaches.ContainsKey(db.DBConn)) return;

        await AsyncLocker.LockAsync($"EnsureInitLockAsync_{db.DBConn}", async () =>
        {
            // Re-check inside the lock: another caller may have completed initialization while we waited.
            if (_lockcaches.ContainsKey(db.DBConn)) return;

            if (!await db.IsTableOrViewExistAsync(DBLockTableName))
            {
                #region Create table DBLockTableName
                await db.ExecuteSqlAsync($@"
create table {DBLockTableName}(
	lock_str varchar(200) unique not null,
	lock_time {db.DateTimeSqlSegment.DefaultDateTimeType} default({db.DateTimeSqlSegment.Current}),
    lock_createtime {db.DateTimeSqlSegment.DefaultDateTimeType} default({db.DateTimeSqlSegment.Current}),
	lock_user varchar(200) not null
)");
                #endregion
            }
            _lockcaches.TryAdd(db.DBConn, true);
        });
    }

    /// <summary>
    /// Repeatedly tries to insert the lock row until it succeeds or the timeout elapses.
    /// </summary>
    /// <param name="db">Database to lock against.</param>
    /// <param name="starttime">Reference time for the timeout (local time, taken by the caller).</param>
    /// <param name="lock_str">Lock name.</param>
    /// <param name="lock_user">Unique identifier of this acquisition; used to refresh/release only our own row.</param>
    /// <param name="timeoutSecond">Seconds after <paramref name="starttime"/> before giving up.</param>
    /// <returns><c>Result.Ok()</c> when the row was inserted; otherwise a failed result with a message.</returns>
    private async Task<Result> TryGetLock(DBAccess db, DateTime starttime, string lock_str, string lock_user, int timeoutSecond = 60 * 5)
    {
        await EnsureInitLockAsync(db);
        var lockStrSql = EscapeSqlLiteral(lock_str);
        var lockUserSql = EscapeSqlLiteral(lock_user);
        while (true)
        {
            try
            {
                // The unique constraint on lock_str makes this fail while someone else holds the lock.
                await db.ExecuteSqlAsync($"insert into {DBLockTableName}(lock_str,lock_user) values('{lockStrSql}','{lockUserSql}')");
                return Result.Ok();
            }
            catch (Exception ex)
            {
                if ((DateTime.Now - starttime).TotalSeconds > timeoutSecond)
                {
                    return Result.NotOk($"获取锁[{lock_str}]超时:{ex.Message}");
                }
                await Task.Delay(500);
                // Remove rows not refreshed in the last 20 seconds: an active holder refreshes its row every 5 seconds.
                try
                {
                    await db.ExecuteSqlAsync($"delete from {DBLockTableName} where lock_time <{db.DateTimeSqlSegment.GetCurrentAddSecond(-20)}");
                }
                catch (Exception ex2)
                {
                    return Result.NotOk($"获取锁[{lock_str}]失败:{ex2.Message}");
                }
            }
        }
    }

    /// <summary>
    /// Starts a fire-and-forget loop that refreshes the lock row's <c>lock_time</c> every 5 seconds.
    /// The loop ends when <paramref name="stopToken"/> is cancelled, when the update no longer affects exactly
    /// one row (the row was deleted), or when the update throws (the error is logged).
    /// </summary>
    private void StartLockMonitor(DBAccess db, string lock_str, string lock_user, CancellationToken stopToken)
    {
        var lockStrSql = EscapeSqlLiteral(lock_str);
        var lockUserSql = EscapeSqlLiteral(lock_user);
        _ = Task.Run(async () =>
        {
            while (!stopToken.IsCancellationRequested)
            {
                // Wait 5 seconds between refreshes; wake up immediately when the lock is released.
                try
                {
                    await Task.Delay(5 * 1000, stopToken);
                }
                catch (OperationCanceledException)
                {
                    break;
                }
                // Extend the lock by refreshing its timestamp.
                try
                {
                    int count = await db.ExecuteSqlAsync($"update {DBLockTableName} set lock_time= {db.DateTimeSqlSegment.Current} where lock_str='{lockStrSql}' and lock_user='{lockUserSql}'");
                    if (count != 1) break;
                }
                catch (Exception ex)
                {
                    logger.LogError($"刷新锁时间失败(lock_str={lock_str},lock_user={lock_user}):{ex?.Message}");
                    break;
                }
            }
        });
    }

    /// <summary>Deletes this acquisition's lock row (matched by both lock name and <paramref name="lock_user"/>).</summary>
    private async Task TryReleaseLock(DBAccess db, string lock_str, string lock_user)
    {
        await EnsureInitLockAsync(db);
        await db.ExecuteSqlAsync($"delete from {DBLockTableName} where lock_str='{EscapeSqlLiteral(lock_str)}' and lock_user='{EscapeSqlLiteral(lock_user)}'");
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Blocks on the asynchronous implementation; failures surface wrapped in an <see cref="AggregateException"/>.
    /// </remarks>
    public void RunInLock(DBAccess db, string lock_str, Action action, int getLockTimeoutSecond = 60 * 3)
    {
        AssertUtil.NotNull(lock_str);
        AssertUtil.NotNull(action);
        var task = RunInLockInternalAsync(db, lock_str, () =>
        {
            action();
            return Task.FromResult(1);
        }, getLockTimeoutSecond);
        task.Wait();
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Blocks on the asynchronous implementation; failures surface wrapped in an <see cref="AggregateException"/>.
    /// </remarks>
    public T RunInLock<T>(DBAccess db, string lock_str, Func<T> func, int getLockTimeoutSecond = 60 * 3)
    {
        AssertUtil.NotNull(lock_str);
        AssertUtil.NotNull(func);
        var task = RunInLockInternalAsync(db, lock_str, () =>
        {
            var res = func();
            return Task.FromResult(res);
        }, getLockTimeoutSecond);
        return task.Result;
    }

    /// <inheritdoc/>
    public async Task RunInLockAsync(DBAccess db, string lock_str, Func<Task> func, int getLockTimeoutSecond = 60 * 3)
    {
        AssertUtil.NotNull(lock_str);
        AssertUtil.NotNull(func);
        await RunInLockInternalAsync(db, lock_str, async () =>
        {
            await func();
            // Dummy value: the internal helper is generic over the result type.
            return 1;
        }, getLockTimeoutSecond);
    }

    /// <inheritdoc/>
    /// <remarks>
    /// The whole locked section (lock statements and <paramref name="func"/>) runs through
    /// <c>db.RunInNoSessionAsync</c>. NOTE(review): the other overloads do not do this — confirm whether the
    /// difference is intentional.
    /// </remarks>
    public async Task<T> RunInLockAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond = 60 * 3)
    {
        AssertUtil.NotNull(lock_str);
        AssertUtil.NotNull(func);
        return await db.RunInNoSessionAsync(async () => await RunInLockInternalAsync(db, lock_str, func, getLockTimeoutSecond));
    }

    /// <summary>
    /// Acquires the in-process semaphore, then the database lock row, runs <paramref name="func"/>, and releases both.
    /// </summary>
    /// <exception cref="Exception">Thrown when either stage cannot be acquired within the timeout.</exception>
    private async Task<T> RunInLockInternalAsync<T>(DBAccess db, string lock_str, Func<Task<T>> func, int getLockTimeoutSecond)
    {
        // Stage 1: serialize callers within this process.
        // `now` is captured before waiting, so the database stage's timeout also counts the time spent here.
        var now = DateTime.Now;
        var asyncLock = _lockSemaphores.GetOrAdd(lock_str, _ => new SemaphoreSlim(1, 1));
        var acquired = await asyncLock.WaitAsync(getLockTimeoutSecond * 1000);
        if (!acquired) throw new Exception($"获取锁[{lock_str}]超时,单机内超时.");
        try
        {
            // Stage 2: acquire the cross-process lock row, identified by a unique lock_user.
            var lock_user = $"{now.ToCommonStampString()} {Guid.NewGuid():N}";
            var res = await TryGetLock(db, now, lock_str, lock_user, getLockTimeoutSecond);
            if (!res.Success) throw new Exception(res.Message);

            using var monitorCts = new CancellationTokenSource();
            StartLockMonitor(db, lock_str, lock_user, monitorCts.Token);
            try
            {
                return await func();
            }
            finally
            {
                try
                {
                    await TryReleaseLock(db, lock_str, lock_user);
                }
                finally
                {
                    // Always stop the keep-alive loop, even if the delete failed. Otherwise it would keep
                    // refreshing the row forever and the stale-lock cleanup could never reclaim it.
                    monitorCts.Cancel();
                }
            }
        }
        finally
        {
            try
            {
                asyncLock.Release();
            }
            catch (SemaphoreFullException ex)
            {
                // Should not happen (we release exactly once per successful wait); log rather than mask the real outcome.
                logger.LogError($"释放单机锁失败(lock_str={lock_str}):{ex.Message}");
            }
        }
    }
}
